[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 #27532
base: main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
vllm/v1/worker/workspace.py (outdated)
```python
def get(self, spec: "WorkspaceSpec") -> torch.Tensor:
    """Get a workspace tensor for the given spec.

    Args:
        spec: The workspace specification.

    Returns:
        A tensor view into the workspace buffer with the requested shape and dtype.
    """
    num_bytes = spec.num_bytes()
    current_workspace = self._ensure_workspace_size(num_bytes, spec.name)
    return current_workspace[:num_bytes].view(spec.dtype).reshape(spec.shape)
```
Allocating workspaces fails due to invalid view call

`WorkspaceManager.get` reinterprets the byte buffer with `current_workspace[:num_bytes].view(spec.dtype)`, but `Tensor.view` only accepts a shape, not a dtype. Passing a `torch.dtype` raises `TypeError: 'torch.dtype' object cannot be interpreted as an integer`, so every call to reserve/get will crash before returning a workspace. The manager needs to reshape using a size tuple and cast with `view(dtype)` via reinterpret_cast semantics (e.g. `view(-1).view(spec.dtype)` or `view(dtype).reshape(...)`).
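For reference, a minimal standalone sketch of the byte-buffer reinterpretation pattern the workspace manager relies on; the names and sizes below are illustrative only and are not the PR's `WorkspaceSpec` API. On a contiguous uint8 slice, `Tensor.view(dtype)` reinterprets the bytes in place and `reshape` then applies the requested shape:

```python
import torch

# Illustrative only: placeholder shape/dtype, not the PR's actual spec.
shape, dtype = (4, 8), torch.bfloat16
num_bytes = 2 * 4 * 8  # bf16 is 2 bytes per element

workspace = torch.empty(1 << 20, dtype=torch.uint8)        # shared byte pool
typed = workspace[:num_bytes].view(dtype).reshape(shape)   # reinterpret bytes, then reshape
assert typed.data_ptr() == workspace.data_ptr()            # zero-copy: same storage
```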
```python
# Process decode tokens
if num_decode_tokens > 0:
    attn_out = self._forward_fp8_kv(
        q[:num_decode_tokens],
        kv_cache,
        topk_indices_global[:num_decode_tokens],
        attn_metadata,
    )

if num_prefill_tokens > 0:
    decode_attn_out = attn_out
    attn_out = q.new_empty(
```
Prefill-only batches reference attn_out before initialization

In the fp8 path of `FlashMLASparseImpl.forward`, `attn_out` is only assigned inside the `if num_decode_tokens > 0` branch. The subsequent `if num_prefill_tokens > 0` branch unconditionally reads `decode_attn_out = attn_out`, which raises `UnboundLocalError` whenever a batch contains only prefill tokens. Prefill batches are common during initial context ingestion, so this path will always fail until `attn_out` is initialized for the prefill case.
```python
if num_prefill_tokens > 0:
    decode_attn_out = attn_out
    attn_out = q.new_empty(
        (num_actual_toks, self.num_heads, self.kv_lora_rank),
        dtype=q.dtype,
        device=q.device,
    )
    attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens]
```
Decode outputs stored into prefill slots

When both decode and prefill tokens exist, the fp8 path copies decode attention results with `attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens]`. Decode tokens occupy the first `num_decode_tokens` entries, so this writes them into the wrong slice and fails whenever `num_prefill_tokens > num_decode_tokens` because the right-hand side is shorter than the target. The assignment should use `num_decode_tokens` to preserve decode outputs and avoid size mismatches.
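Taken together, the two comments above suggest allocating `attn_out` unconditionally and copying only the decode slice. A hedged sketch of that control flow with dummy shapes; the helper function and all sizes are illustrative assumptions, not the PR's actual implementation:

```python
import torch

def combine_decode_prefill(q, decode_attn_out, num_decode_tokens, num_prefill_tokens,
                           num_heads, kv_lora_rank):
    # Allocate the combined output unconditionally so prefill-only batches
    # never read an uninitialized decode result.
    num_actual_toks = num_decode_tokens + num_prefill_tokens
    attn_out = q.new_empty((num_actual_toks, num_heads, kv_lora_rank), dtype=q.dtype)
    if num_decode_tokens > 0:
        # Decode tokens occupy the first num_decode_tokens rows, so copy only that slice.
        attn_out[:num_decode_tokens] = decode_attn_out
    # Rows [num_decode_tokens:] are left for the prefill kernel to fill in.
    return attn_out

# Dummy usage: 2 decode tokens + 4 prefill tokens, 16 heads, kv_lora_rank of 512.
q = torch.randn(6, 16, 576)
decode_attn_out = torch.randn(2, 16, 512)
out = combine_decode_prefill(q, decode_attn_out, 2, 4, num_heads=16, kv_lora_rank=512)
assert out.shape == (6, 16, 512)
```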
Force-pushed from 7a3b6b6 to 39ba79c.
```python
    None,  # Pass None to avoid using sampled token counts
)

current_workspace_manager().get_simultaneous(
```
Do we have to do these allocations during model execution? Is it possible to set up the memory buffers before real execution to reduce the runtime overhead?

For example, I'm thinking of:

- during model init, call `current_workspace_manager().get_simultaneous()` to tell the workspace manager the max possible size the model may use
- lock the memory space
- during the profile run, allocate the memory and save it to `self.xxx`, like `self.workspace13 = workspace13; self.workspace2 = workspace2`
- during model execution, just use `self.workspace13`
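A toy sketch of this reserve-at-init / allocate-at-profile / reuse-at-execution flow; `ToyWorkspaceManager` and all sizes below are made up for illustration and are not vLLM's actual workspace manager API:

```python
import torch

class ToyWorkspaceManager:
    """Minimal stand-in that sizes one shared byte pool up front and hands out views."""

    def __init__(self):
        self._reserved_bytes = 0
        self._buffer = None

    def reserve(self, *num_bytes: int) -> None:
        # Model init: record the worst-case simultaneous footprint.
        self._reserved_bytes = max(self._reserved_bytes, sum(num_bytes))

    def allocate(self, device: str = "cpu") -> None:
        # Profile run: allocate the shared buffer exactly once.
        self._buffer = torch.empty(self._reserved_bytes, dtype=torch.uint8, device=device)

    def get(self, num_bytes: int, dtype: torch.dtype, shape, offset: int = 0) -> torch.Tensor:
        # Model execution: zero-copy typed views, no allocation on the hot path.
        return self._buffer[offset:offset + num_bytes].view(dtype).reshape(shape)

mgr = ToyWorkspaceManager()
ws13_bytes, ws2_bytes = 2 * 1024 * 4096, 2 * 1024 * 2048   # placeholder bf16 worst-case sizes
mgr.reserve(ws13_bytes, ws2_bytes)                          # at model init
mgr.allocate()                                              # at profile run
workspace13 = mgr.get(ws13_bytes, torch.bfloat16, (1024, 4096))
workspace2 = mgr.get(ws2_bytes, torch.bfloat16, (1024, 2048), offset=ws13_bytes)
```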
Ya, we should do something like this. The only complication currently is when `self.fused_experts.supports_chunking() == False` (i.e. PPLX or DeepEP LL): then we need the profile run to know the shape of the workspaces, because in that case the profile run actually mirrors the worst-case scenario (hence gating this logic).

This code is meant to be a straight refactor of #27426 (once that lands) to use the new workspace manager, so I'm partial to leaving this optimization to a future PR if you are cool with it. I am trying to learn more about the MoE chunking code in order to propose a broader UX refactor there, since I find it confusing currently, and I think this optimization could be part of that.
Yeah, works for me. Can you do some benchmarks to ensure there's no perf regression?
```python
    ((total_seq_lens, 4), torch.uint8),
)

return sparse_attn_indexer_fake(
```
Should these lines in `sparse_attn_indexer_fake` be removed? I wrote them to mimic the activation memory during the profile run.

```python
_flattened_kv = torch.empty(
    [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
_k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
```
They're back 👍 (not sure when, haha; I guess I must have done it and just not pushed it).
I want to say that if you choose to reserve the buffers of shape `(total_seq_lens, head_dim)` and `(total_seq_lens, 4)` in the workspace, you don't need to run the following in `sparse_attn_indexer_fake`

```python
_flattened_kv = torch.empty(
    [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
```

to also reserve memory for these two tensors in activation memory.
Force-pushed from 39ba79c to 4a49fc9.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from a3f6647 to f665def.
When doing prefill, up-convert the kv-cache from fp8 to bf16 and call the bf16 prefill kernel instead of the decode kernel. This PR introduces global workspace management so that the bf16 workspace can overlap with the MoE workspace buffers.
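A rough sketch of the high-level idea in the description: for prefill, dequantize the fp8 kv-cache into a bf16 buffer carved out of a shared workspace, then run the bf16 prefill kernel on that buffer. The shapes, per-token scale layout, and buffer handling below are illustrative assumptions, not the PR's actual kernel or cache layout:

```python
import torch

num_tokens, head_dim = 128, 576                                      # illustrative sizes
kv_fp8 = torch.randn(num_tokens, head_dim).to(torch.float8_e4m3fn)   # stored fp8 cache (dummy)
kv_scale = torch.rand(num_tokens, 1) + 0.5                           # assumed per-token scales

# Carve a bf16 view out of a shared uint8 workspace (2 bytes per bf16 element).
workspace = torch.empty(num_tokens * head_dim * 2, dtype=torch.uint8)
kv_bf16 = workspace.view(torch.bfloat16).reshape(num_tokens, head_dim)

# Up-convert: dequantize fp8 -> bf16 into the workspace buffer.
kv_bf16.copy_(kv_fp8.to(torch.bfloat16) * kv_scale)
# ...a bf16 prefill attention kernel would then consume kv_bf16 instead of kv_fp8.
```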